package signature

import (
	"errors"
	"fmt"
	"strconv"
	"time"
	"bytes"
	"crypto"
	"strings"
	"compress/zlib"
	"crypto/ecdsa"
	"crypto/elliptic"
	"crypto/rand"
	"crypto/sha256"
	"crypto/x509"
	"crypto/x509/pkix"
	"encoding/asn1"
	"encoding/base64"
	"encoding/json"
	"encoding/pem"
	"io/ioutil"
	"math/big"
)

var (
	// signReplace lists the substitutions encode applies to standard base64
	// output, and that decode reverses: '+' -> '*', '/' -> '-', '=' -> '_'.
	signReplace = map[string]string{
		"+": "*",
		"/": "-",
		"=": "_",
	}

	// oidPublicKeyECDSA is the id-ecPublicKey algorithm OID (1.2.840.10045.2.1).
	oidPublicKeyECDSA = asn1.ObjectIdentifier{1, 2, 840, 10045, 2, 1}

	// oidNamedCurveS256 is the secp256k1 named-curve OID (1.3.132.0.10). The
	// key parsers below recognise it when crypto/x509 rejects the curve.
	oidNamedCurveS256 = asn1.ObjectIdentifier{1, 3, 132, 0, 10}
)

// pkcs mirrors the leading fields of a PKCS#8 PrivateKeyInfo (RFC 5208):
// the syntax version, the key algorithm identifier and the algorithm-specific
// private key octets. Optional trailing attributes are not captured.
// parsePrivateKey uses it to unpack keys whose curve crypto/x509 rejects.
type pkcs struct {
	Version    int
	Algorithm  pkix.AlgorithmIdentifier
	PrivateKey []byte
}

// ecPublicKey mirrors a PKIX SubjectPublicKeyInfo (RFC 5280). On Unmarshal,
// Raw receives the complete DER encoding. PublicKey holds the encoded curve
// point. parsePublicKey uses it for secp256k1 keys that crypto/x509 rejects.
type ecPublicKey struct {
	Raw       asn1.RawContent
	Algorithm pkix.AlgorithmIdentifier
	PublicKey asn1.BitString
}

// ecPrivateKey mirrors the SEC 1 ECPrivateKey structure (RFC 5915). It holds:
//   - the version;
//   - the private scalar as big-endian octets;
//   - an optional, explicitly tagged [0] curve OID;
//   - an optional, explicitly tagged [1] public key point.
type ecPrivateKey struct {
	Version       int
	PrivateKey    []byte
	NamedCurveOID asn1.ObjectIdentifier `asn1:"optional,explicit,tag:0"`
	PublicKey     asn1.BitString        `asn1:"optional,explicit,tag:1"`
}

// ecdsaSignature is the ASN.1 ECDSA-Sig-Value, SEQUENCE { r INTEGER, s INTEGER }.
// Verify decodes the DER signature carried in TLS.sig into it.
type ecdsaSignature struct {
	R *big.Int
	S *big.Int
}

// options holds the fields of a signature token. It is serialized to JSON
// under the "TLS.*" keys in the tags below. json.Marshal emits them in
// declaration order.
type options struct {
	AccountType string `json:"TLS.account_type"`
	Identifier  string `json:"TLS.identifier"`
	AppIDAt3rd  string `json:"TLS.appid_at_3rd"`
	SdkAppID    string `json:"TLS.sdk_appid"`
	Expire      string `json:"TLS.expire_after"`
	Version     string `json:"TLS.version"`
	Time        string `json:"TLS.time"`
	Sign        string `json:"TLS.sig"`
}

// Signature issues and verifies TLS signature tokens for one SDK app ID.
// Create it with New, optionally adjust it with WithExpire and
// WithAccountType, then call Generate or Verify. Both Generate and Verify
// write to the unexported options without locking. Do not share a Signature
// between goroutines without external synchronization.
//
// NOTE(review): PrivateKey and PublicKey are not read by the methods in this
// file, because Generate and Verify take their keys as arguments. Confirm
// whether other files in the package use them.
type Signature struct {
	AppID      int
	PrivateKey string
	PublicKey  string

	options options
}

// New returns a Signature for appID, filled with these package defaults:
//   - token version "201512300000";
//   - appid_at_3rd "0" and account_type "0";
//   - a validity of 180 days, expressed in seconds;
//   - the current Unix time as the issue time.
func New(appID int) *Signature {
	const defaultExpire = 180 * 24 * 3600 // seconds

	s := &Signature{AppID: appID}
	s.options = options{
		Version:     "201512300000",
		AppIDAt3rd:  "0",
		AccountType: "0",
		SdkAppID:    fmt.Sprint(appID),
		Expire:      fmt.Sprint(defaultExpire),
		Time:        fmt.Sprint(time.Now().Unix()),
	}
	return s
}

// WithExpire sets TLS.expire_after: how many seconds after the issue time a
// token stays valid. It returns s so calls can be chained.
func (s *Signature) WithExpire(expire int) *Signature {
	s.options.Expire = fmt.Sprint(expire)
	return s
}

// WithAccountType sets the TLS.account_type value written into tokens. It
// returns s so calls can be chained.
func (s *Signature) WithAccountType(accountType int) *Signature {
	s.options.AccountType = fmt.Sprint(accountType)
	return s
}

// Verify reports whether sign is a genuine, unexpired token for s.AppID,
// signed by the key matching the PEM-encoded publicKey.
//
// sign is decoded with decode, zlib-decompressed and parsed as JSON. Verify
// returns an error if any of the following hold:
//   - one of those decoding steps fails;
//   - TLS.sdk_appid differs from s.AppID;
//   - TLS.time + TLS.expire_after lies in the past;
//   - the key or the ASN.1 signature cannot be parsed.
//
// Otherwise it returns the result of the ECDSA check with a nil error. So
// (false, nil) means the token is well-formed but not signed by this key.
//
// The decoded token fields replace s's options and stay visible to later
// calls on s, such as Generate.
func (s *Signature) Verify(sign, publicKey string) (bool, error) {
	data, err := decode(sign)
	if err != nil {
		return false, err
	}

	reader, err := zlib.NewReader(bytes.NewReader(data))
	if err != nil {
		return false, err
	}
	// Release the decompressor. The Close error is irrelevant once the
	// stream has been read (or has failed) below.
	defer reader.Close()

	data, err = ioutil.ReadAll(reader)
	if err != nil {
		return false, err
	}

	// Decode into a zero value. A field missing from the token then reads
	// as "" rather than silently reusing whatever an earlier Generate or
	// Verify left in s.options (including a previously generated TLS.sig).
	var opt options
	if err := json.Unmarshal(data, &opt); err != nil {
		return false, err
	}
	s.options = opt

	if s.options.SdkAppID != strconv.Itoa(s.AppID) {
		return false, errors.New("appID not match")
	}

	ct, err := strconv.ParseInt(s.options.Time, 10, 64)
	if err != nil {
		return false, err
	}

	ex, err := strconv.ParseInt(s.options.Expire, 10, 64)
	if err != nil {
		return false, err
	}

	if ct+ex < time.Now().Unix() {
		return false, errors.New("expired")
	}

	sha := sha256.Sum256([]byte(s.content()))

	signature, err := base64.StdEncoding.DecodeString(s.options.Sign)
	if err != nil {
		return false, err
	}

	var ec ecdsaSignature
	if _, err := asn1.Unmarshal(signature, &ec); err != nil {
		return false, err
	}

	pk, err := parsePublicKey(publicKey)
	if err != nil {
		return false, err
	}

	return ecdsa.Verify(pk, sha[:], ec.R, ec.S), nil
}

// Generate issues a signature token for identifier, signed with the
// PEM-encoded privateKey.
//
// The token is the JSON encoding of s's options, including the base64 ECDSA
// signature under "TLS.sig". That JSON is zlib-compressed and passed through
// encode. TLS.time is stamped with the current Unix time on every call, so a
// Signature that is kept and reused issues tokens that each start a fresh
// validity window.
//
// Generate stores identifier, the issue time and the signature in s.
func (s *Signature) Generate(identifier, privateKey string) (string, error) {
	s.options.Identifier = identifier
	// Without this, every token from a reused Signature would carry the time
	// New was called. Once Expire seconds had passed since New, new tokens
	// would already be expired.
	s.options.Time = strconv.FormatInt(time.Now().Unix(), 10)

	sign, err := s.sign(privateKey)
	if err != nil {
		return "", err
	}
	s.options.Sign = sign

	data, err := json.Marshal(s.options)
	if err != nil {
		return "", err
	}

	var b bytes.Buffer
	z := zlib.NewWriter(&b)
	if _, err := z.Write(data); err != nil {
		return "", err
	}
	// Close flushes the final compressed block and the Adler-32 trailer.
	// Ignoring its error could yield a truncated token.
	if err := z.Close(); err != nil {
		return "", err
	}

	return encode(b.Bytes()), nil
}

// sign signs the SHA-256 digest of s.content() with the PEM-encoded
// privateKey. It returns the signature bytes produced by the key's Sign
// method (ASN.1, as Verify later decodes them), encoded with standard base64.
func (s *Signature) sign(privateKey string) (string, error) {
	key, err := parsePrivateKey(privateKey)
	if err != nil {
		return "", err
	}

	digest := sha256.Sum256([]byte(s.content()))
	der, err := key.Sign(rand.Reader, digest[:], crypto.SHA256)
	if err != nil {
		return "", err
	}
	return base64.StdEncoding.EncodeToString(der), nil
}

// content builds the text that is hashed and signed. It is one
// "key:value\n" line per field, in a fixed order that Generate and Verify
// both rely on. TLS.version and TLS.sig are not included.
func (s *Signature) content() string {
	o := &s.options
	lines := [...]struct{ prefix, value string }{
		{"TLS.appid_at_3rd:", o.AppIDAt3rd},
		{"TLS.account_type:", o.AccountType},
		{"TLS.identifier:", o.Identifier},
		{"TLS.sdk_appid:", o.SdkAppID},
		{"TLS.time:", o.Time},
		{"TLS.expire_after:", o.Expire},
	}

	var out strings.Builder
	for _, l := range lines {
		out.WriteString(l.prefix)
		out.WriteString(l.value)
		out.WriteString("\n")
	}
	return out.String()
}

// parsePublicKey decodes a PEM block holding a PKIX (SubjectPublicKeyInfo)
// ECDSA public key.
//
// Keys on curves that crypto/x509 supports are parsed by
// x509.ParsePKIXPublicKey. crypto/x509 rejects secp256k1 keys with an
// "unsupported elliptic curve" error; those are decoded by hand and bound to
// S256(). Any other unsupported curve yields that original x509 error.
func parsePublicKey(publicKey string) (*ecdsa.PublicKey, error) {
	block, _ := pem.Decode([]byte(publicKey))
	if block == nil {
		return nil, errors.New("invalid publicKey")
	}

	pk, err := x509.ParsePKIXPublicKey(block.Bytes)
	if err == nil {
		key, ok := pk.(*ecdsa.PublicKey)
		if !ok {
			return nil, errors.New("invalid publicKey")
		}
		return key, nil
	}
	if !strings.Contains(err.Error(), "unsupported elliptic curve") {
		return nil, err
	}

	// secp256k1 fallback. Parse failures use their own variable so that err
	// keeps the x509 error for curves we do not handle. Overwriting it with
	// nil would make this function return (nil, nil).
	var spki ecPublicKey
	if _, perr := asn1.Unmarshal(block.Bytes, &spki); perr != nil {
		return nil, perr
	}
	var curveOID asn1.ObjectIdentifier
	if _, perr := asn1.Unmarshal(spki.Algorithm.Parameters.FullBytes, &curveOID); perr != nil {
		return nil, perr
	}
	if !spki.Algorithm.Algorithm.Equal(oidPublicKeyECDSA) || !curveOID.Equal(oidNamedCurveS256) {
		return nil, err
	}

	curve := S256()
	// elliptic.Unmarshal returns nil coordinates for a malformed encoding or
	// a point that is not on the curve.
	x, y := elliptic.Unmarshal(curve, spki.PublicKey.RightAlign())
	if x == nil {
		return nil, errors.New("invalid publicKey")
	}
	return &ecdsa.PublicKey{Curve: curve, X: x, Y: y}, nil
}

// parsePrivateKey decodes a PEM block holding a PKCS#8 ECDSA private key.
//
// Keys on curves that crypto/x509 supports are parsed by
// x509.ParsePKCS8PrivateKey. crypto/x509 rejects secp256k1 keys with an
// "unknown elliptic curve" error. Those are unpacked by hand (the PKCS#8
// wrapper, then the SEC 1 ECPrivateKey inside it) and bound to S256(), and
// the public point is recomputed from the scalar. Any other unsupported curve
// yields the original x509 error.
func parsePrivateKey(privateKey string) (*ecdsa.PrivateKey, error) {
	block, _ := pem.Decode([]byte(privateKey))
	if block == nil {
		return nil, errors.New("invalid privateKey")
	}

	pk, err := x509.ParsePKCS8PrivateKey(block.Bytes)
	if err == nil {
		key, ok := pk.(*ecdsa.PrivateKey)
		if !ok {
			return nil, errors.New("invalid privateKey")
		}
		return key, nil
	}
	if !strings.Contains(err.Error(), "unknown elliptic curve") {
		return nil, err
	}

	// secp256k1 fallback. Every ASN.1 step is checked so a malformed key is
	// reported rather than turned into a zero scalar.
	var p8 pkcs
	if _, perr := asn1.Unmarshal(block.Bytes, &p8); perr != nil {
		return nil, perr
	}
	if !p8.Algorithm.Algorithm.Equal(oidPublicKeyECDSA) {
		return nil, err
	}
	var curveOID asn1.ObjectIdentifier
	if _, perr := asn1.Unmarshal(p8.Algorithm.Parameters.FullBytes, &curveOID); perr != nil {
		return nil, perr
	}
	if !curveOID.Equal(oidNamedCurveS256) {
		return nil, err
	}
	var sec1 ecPrivateKey
	if _, perr := asn1.Unmarshal(p8.PrivateKey, &sec1); perr != nil {
		return nil, perr
	}

	curve := S256()
	d := new(big.Int).SetBytes(sec1.PrivateKey)
	// The scalar must lie in [1, N-1]. Zero (e.g. from an empty octet
	// string) or anything >= N is not a usable private key.
	if d.Sign() == 0 || d.Cmp(curve.Params().N) >= 0 {
		return nil, errors.New("invalid privateKey")
	}

	k := new(ecdsa.PrivateKey)
	k.Curve = curve
	k.D = d
	k.X, k.Y = curve.ScalarBaseMult(d.Bytes())
	return k, nil
}

// encode base64-encodes data with the standard alphabet. It then swaps each
// character listed in signReplace for its substitute, producing the token
// form that decode accepts.
func encode(data []byte) string {
	pairs := make([]string, 0, 2*len(signReplace))
	for std, sub := range signReplace {
		pairs = append(pairs, std, sub)
	}
	return strings.NewReplacer(pairs...).Replace(base64.StdEncoding.EncodeToString(data))
}

// decode reverses encode. It maps every signReplace substitute back to its
// standard base64 character, then decodes with the standard alphabet.
// Characters already in standard form pass through unchanged.
func decode(data string) ([]byte, error) {
	pairs := make([]string, 0, 2*len(signReplace))
	for std, sub := range signReplace {
		pairs = append(pairs, sub, std)
	}
	return base64.StdEncoding.DecodeString(strings.NewReplacer(pairs...).Replace(data))
}